{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " # CNN Example with PySNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "from pysnn.network import SNNNetwork\n",
    "from pysnn.connection import Conv2d, AdaptiveMaxPool2d\n",
    "from pysnn.neuron import FedeNeuron, Input\n",
    "from pysnn.learning import FedeSTDP\n",
    "from pysnn.utils import conv2d_output_shape\n",
    "from pysnn.datasets import nmnist_train_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Parameters used during simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture\n",
    "n_in = 5\n",
    "n_hidden1 = 10\n",
    "n_hidden2 = 10\n",
    "n_out = 5\n",
    "\n",
    "# Data\n",
    "sample_length = 300\n",
    "num_workers = 0\n",
    "batch_size = 1\n",
    "\n",
    "# Neuronal Dynamics\n",
    "thresh = 0.8\n",
    "v_rest = 0\n",
    "alpha_v = 0.2\n",
    "tau_v = 5\n",
    "alpha_t = 1.0\n",
    "tau_t = 5\n",
    "duration_refrac = 5\n",
    "dt = 1\n",
    "delay = 3\n",
    "n_dynamics = (thresh, v_rest, alpha_v, alpha_t, dt, duration_refrac, tau_v, tau_t)\n",
    "c_dynamics = (batch_size, dt, delay)\n",
    "i_dynamics = (dt, alpha_t, tau_t)\n",
    "\n",
    "# Learning\n",
    "lr = 0.0001\n",
    "w_init = 0.5\n",
    "a = 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Network definition\n",
    " The API is mostly the same aas for regular PyTorch. The main differences are that layers are composed of a Neuron and Connection type, and the layer has to be added to the network by calling the add_layer method. Lastly, all objects return both a spike (or activation potential) object and a trace object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(SNNNetwork):\n",
    "    def __init__(self):\n",
    "        super(Network, self).__init__()\n",
    "\n",
    "        # Input\n",
    "        self.input = Input(\n",
    "            (batch_size, 2, 34, 34), *i_dynamics, update_type=\"exponential\"\n",
    "        )\n",
    "\n",
    "        # Layer 1\n",
    "        self.conv1 = Conv2d(2, 4, 5, (34, 34), *c_dynamics, padding=1, stride=1)\n",
    "        self.neuron1 = FedeNeuron((batch_size, 4, 32, 32), *n_dynamics)\n",
    "        self.add_layer(\"conv1\", self.conv1, self.neuron1)\n",
    "\n",
    "        # Layer 2\n",
    "        self.conv2 = Conv2d(4, 8, 5, (32, 32), *c_dynamics, padding=1, stride=2)\n",
    "        self.neuron2 = FedeNeuron((batch_size, 8, 15, 15), *n_dynamics)\n",
    "        self.add_layer(\"conv2\", self.conv2, self.neuron2)\n",
    "\n",
    "        # Layer out\n",
    "        self.conv3 = Conv2d(8, 1, 3, (15, 15), *c_dynamics)\n",
    "        self.neuron3 = FedeNeuron((batch_size, 1, 13, 13), *n_dynamics)\n",
    "        self.add_layer(\"conv3\", self.conv3, self.neuron3)\n",
    "\n",
    "    def forward(self, input):\n",
    "        x, t = self.input(input)\n",
    "\n",
    "        # Layer 1\n",
    "        x, t = self.conv1(x, t)\n",
    "        x, t = self.neuron1(x, t)\n",
    "\n",
    "        # Layer 2\n",
    "        # x = self.pool2(x)\n",
    "        x, t = self.conv2(x, t)\n",
    "        x, t = self.neuron2(x, t)\n",
    "\n",
    "        # Layer out\n",
    "        x, t = self.conv3(x, t)\n",
    "        x, t = self.neuron3(x, t)\n",
    "\n",
    "        return x, t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Dataset\n",
    " The dataset is the Neuromorphic NMNIST dataset, obtained from the following website: [NMNIST-link](https://www.garrickorchard.com/datasets/n-mnist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "root = \"nminst\"\n",
    "if os.path.isdir(root):\n",
    "    train_dataset, test_dataset = nmnist_train_test(root)\n",
    "else:\n",
    "    raise NotADirectoryError(\n",
    "        \"Make sure to download the N-MNIST dataset from https://www.garrickorchard.com/datasets/n-mnist and put it in the 'nmnist' folder.\"\n",
    "    )\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers\n",
    ")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Training\n",
    " Training is performed on just three images in order to keep computational time low.\n",
    " \n",
    " Automatically uses GPU if available on system, also shows the possibility of using 16 bit floating point calculations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = Network()\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "    net = net.to(torch.float16).cuda()\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "learning_rule = FedeSTDP(net.layer_state_dict(), lr, w_init, a)\n",
    "\n",
    "# Go over three items\n",
    "output = []\n",
    "for _ in range(3):\n",
    "    out = []\n",
    "    batch = next(iter(train_dataloader))\n",
    "    input = batch[0]\n",
    "    for idx in range(input.shape[-1]):\n",
    "        x = input[:, :, :, :, idx].to(device)\n",
    "        y, _ = net(x)\n",
    "        out.append(y)\n",
    "\n",
    "        learning_rule.step()\n",
    "\n",
    "    net.reset_state()\n",
    "    output.append(torch.stack(out, dim=-1))\n",
    "\n",
    "print(output[0].shape)"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
